{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Tce3stUlHN0L"
   },
   "source": [
    "##### Copyright 2019 The TensorFlow Authors.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:16.495560Z",
     "iopub.status.busy": "2023-11-07T23:15:16.495018Z",
     "iopub.status.idle": "2023-11-07T23:15:16.499343Z",
     "shell.execute_reply": "2023-11-07T23:15:16.498703Z"
    },
    "id": "tuOe1ymfHZPu"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xHxb-dlhMIzW"
   },
   "source": [
    "## 概述\n",
    "\n",
    "本教程演示如何使用 Keras 模型和 <code>tf.distribute.Strategy</code> API 的<a>自定义训练循环</a>执行多工作进程分布式训练。训练循环通过 `tf.distribute.MultiWorkerMirroredStrategy` 进行分布。这样，设计为在[单个工作进程上运行的 `tf.keras` 模型](custom_training.ipynb)即可通过最少的代码更改无缝地在多个工作进程上运行。自定义训练循环提供了灵活性和更好的训练控制，同时也使模型的调试更加容易。请详细了解有关[编写基本训练循环](../../guide/basic_training_loops.ipynb)、 [从头开始编写训练循环](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch)和[自定义训练](../customization/custom_training_walkthrough.ipynb)的信息。\n",
    "\n",
    "如果您正在寻找如何将 `MultiWorkerMirroredStrategy` 与 `tf.keras.Model.fit` 一起使用，请参阅此[教程](multi_worker_with_keras.ipynb)。\n",
    "\n",
    "[TensorFlow 中的分布式训练](../../guide/distributed_training.ipynb)指南概述了 TensorFlow 支持的分布式策略，并适用于想要更深入了解 `tf.distribute.Strategy` API 的人。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MUXex9ctTuDB"
   },
   "source": [
    "## 安装\n",
    "\n",
    "首先，进行一些必要的导入。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:16.503236Z",
     "iopub.status.busy": "2023-11-07T23:15:16.502604Z",
     "iopub.status.idle": "2023-11-07T23:15:16.509179Z",
     "shell.execute_reply": "2023-11-07T23:15:16.508511Z"
    },
    "id": "bnYxvfLD-LW-"
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import sys"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Zz0EY91y3mxy"
   },
   "source": [
    "在导入 TensorFlow 之前，需要对环境进行一些变更：\n",
    "\n",
    "- 停用所有 GPU。这可以防止所有工作进程都尝试使用同一个 GPU 而导致的错误。对于真实应用，每个工作进程都将在不同的计算机上运行。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:16.512523Z",
     "iopub.status.busy": "2023-11-07T23:15:16.511976Z",
     "iopub.status.idle": "2023-11-07T23:15:16.515646Z",
     "shell.execute_reply": "2023-11-07T23:15:16.515038Z"
    },
    "id": "685pbYEY3jGC"
   },
   "outputs": [],
   "source": [
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7X1MS6385BWi"
   },
   "source": [
    "- 重置 `'TF_CONFIG'` 环境变量（稍后您将看到更多相关信息）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:16.519247Z",
     "iopub.status.busy": "2023-11-07T23:15:16.518695Z",
     "iopub.status.idle": "2023-11-07T23:15:16.521957Z",
     "shell.execute_reply": "2023-11-07T23:15:16.521355Z"
    },
    "id": "WEJLYa2_7OZF"
   },
   "outputs": [],
   "source": [
    "os.environ.pop('TF_CONFIG', None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Rd4L9Ii77SS8"
   },
   "source": [
    "- 确保当前目录位于 Python 的路径上。这样，笔记本可以导入稍后由 `%%writefile` 写入的文件。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:16.525511Z",
     "iopub.status.busy": "2023-11-07T23:15:16.524881Z",
     "iopub.status.idle": "2023-11-07T23:15:16.528298Z",
     "shell.execute_reply": "2023-11-07T23:15:16.527662Z"
    },
    "id": "hPBuZUNSZmrQ"
   },
   "outputs": [],
   "source": [
    "if '.' not in sys.path:\n",
    "  sys.path.insert(0, '.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pDhHuMjb7bfU"
   },
   "source": [
    "现在导入 TensorFlow。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:16.531882Z",
     "iopub.status.busy": "2023-11-07T23:15:16.531274Z",
     "iopub.status.idle": "2023-11-07T23:15:19.109502Z",
     "shell.execute_reply": "2023-11-07T23:15:19.108553Z"
    },
    "id": "vHNvttzV43sA"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:16.986830: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2023-11-07 23:15:16.986881: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2023-11-07 23:15:16.988638: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0S2jpf6Sx50i"
   },
   "source": [
    "### 数据集和模型定义"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fLW6D2TzvC-4"
   },
   "source": [
    "接下来，使用简单的模型和数据集设置创建 `mnist.py` 文件。本教程中的工作进程将使用此 Python 文件："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:19.114488Z",
     "iopub.status.busy": "2023-11-07T23:15:19.113982Z",
     "iopub.status.idle": "2023-11-07T23:15:19.120140Z",
     "shell.execute_reply": "2023-11-07T23:15:19.119342Z"
    },
    "id": "dma_wUAxZqo2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing mnist.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile mnist.py\n",
    "\n",
    "import os\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "def mnist_dataset(batch_size):\n",
    "  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()\n",
    "  # The `x` arrays are in uint8 and have values in the range [0, 255].\n",
    "  # You need to convert them to float32 with values in the range [0, 1]\n",
    "  x_train = x_train / np.float32(255)\n",
    "  y_train = y_train.astype(np.int64)\n",
    "  train_dataset = tf.data.Dataset.from_tensor_slices(\n",
    "      (x_train, y_train)).shuffle(60000)\n",
    "  return train_dataset\n",
    "\n",
    "def dataset_fn(global_batch_size, input_context):\n",
    "  batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n",
    "  dataset = mnist_dataset(batch_size)\n",
    "  dataset = dataset.shard(input_context.num_input_pipelines,\n",
    "                          input_context.input_pipeline_id)\n",
    "  dataset = dataset.batch(batch_size)\n",
    "  return dataset\n",
    "\n",
    "def build_cnn_model():\n",
    "  return tf.keras.Sequential([\n",
    "      tf.keras.Input(shape=(28, 28)),\n",
    "      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
    "      tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
    "      tf.keras.layers.Flatten(),\n",
    "      tf.keras.layers.Dense(128, activation='relu'),\n",
    "      tf.keras.layers.Dense(10)\n",
    "  ])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JmgZwwymxqt5"
   },
   "source": [
    "## 多工作进程配置\n",
    "\n",
    "接下来，我们进入多工作进程训练的世界。在 TensorFlow 中，在多台计算机上进行训练需要 `'TF_CONFIG'` 环境变量。每台计算机可能有不同的角色。下面使用的 `'TF_CONFIG'` 变量是一个 JSON 字符串，它指定集群中每个工作进程的集群配置。这是使用 `cluster_resolver.TFConfigClusterResolver` 指定集群的默认方法，但在 `distribute.cluster_resolver` 模块中还有其他可用选项。请在[分布式训练指南](../../guide/distributed_training.ipynb)中了解有关设置 `'TF_CONFIG'` 变量的更多信息。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SS8WhvRhe_Ya"
   },
   "source": [
    "### 描述您的集群\n",
    "\n",
    "下面是一个示例配置："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:19.124437Z",
     "iopub.status.busy": "2023-11-07T23:15:19.123838Z",
     "iopub.status.idle": "2023-11-07T23:15:19.127351Z",
     "shell.execute_reply": "2023-11-07T23:15:19.126741Z"
    },
    "id": "XK1eTYvSZiX7"
   },
   "outputs": [],
   "source": [
    "tf_config = {\n",
    "    'cluster': {\n",
    "        'worker': ['localhost:12345', 'localhost:23456']\n",
    "    },\n",
    "    'task': {'type': 'worker', 'index': 0}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JjgwJbPKZkJL"
   },
   "source": [
    "请注意，`tf_config` 只是 Python 中的局部变量。要将其用于训练配置，请将其序列化为 JSON 并将其放在 `'TF_CONFIG'` 环境变量中。这是序列化为 JSON 字符串的相同 `'TF_CONFIG'`："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:19.130890Z",
     "iopub.status.busy": "2023-11-07T23:15:19.130639Z",
     "iopub.status.idle": "2023-11-07T23:15:19.137618Z",
     "shell.execute_reply": "2023-11-07T23:15:19.136927Z"
    },
    "id": "yY-T0YDQZjbu"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{\"cluster\": {\"worker\": [\"localhost:12345\", \"localhost:23456\"]}, \"task\": {\"type\": \"worker\", \"index\": 0}}'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "json.dumps(tf_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AUBmYRZqxthH"
   },
   "source": [
    "`'TF_CONFIG'` 有两个组件：`'cluster'` 和 `'task'`。\n",
    "\n",
    "- `'cluster'` 对所有工作进程都相同，并提供有关训练集群的信息，这是一个由不同类型的作业组成的字典，例如 `'worker'` 。在使用 `MultiWorkerMirroredStrategy` 进行的多工作进程训练中，除了普通的 `'worker'` 之外，通常还有一个 `'worker'` 承担更多的责任，例如保存检查点和为 TensorBoard 编写摘要文件。这样的工作进程被称为 `'chief'` 工作进程，习惯上将 `'index'` 为 0 的 `'worker'` 指定为首席 `worker`。\n",
    "\n",
    "- `'task'` 提供当前任务的信息，并且在每个工作进程上都不相同。它指定该工作进程的 `'type'` 和 `'index'`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8YFpxrcsZ2xG"
   },
   "source": [
    "在本例中，您会将任务 `'type'` 设置为 `'worker'`，将任务 `'index'` 设置为 `0`。这台计算机是首个工作进程，将被指定为首席工作进程，并需要比其他工作进程承担更多的工作。请注意，其他计算机也需要设置 `'TF_CONFIG'` 环境变量，且应该具有相同的 `'cluster'` 字典，但要根据这些计算机的具体角色来设置不同的任务 `'type'` 或任务 `'index'`。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aogb74kHxynz"
   },
   "source": [
    "出于演示的目的，本教程将展示如何在 `'localhost'` 上设置具有两个工作进程的 `'TF_CONFIG'`。在实践中，用户会在外部 IP 地址/端口上创建多个工作进程，并为每个工作进程正确设置 `'TF_CONFIG'`。\n",
    "\n",
    "本示例使用两个工作进程，第一个工作进程的 `'TF_CONFIG'` 如上所示。对于第二个工作进程，设置 `tf_config['task']['index']=1`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cIlkfWmjz1PG"
   },
   "source": [
    "### 笔记本中的环境变量和子进程"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FcjAbuGY1ACJ"
   },
   "source": [
    "子进程会从其父进程继承环境变量。因此，如果您在此 Jupyter Notebook 进程中设置环境变量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:19.141656Z",
     "iopub.status.busy": "2023-11-07T23:15:19.141385Z",
     "iopub.status.idle": "2023-11-07T23:15:19.144932Z",
     "shell.execute_reply": "2023-11-07T23:15:19.144267Z"
    },
    "id": "PH2gHn2_0_U8"
   },
   "outputs": [],
   "source": [
    "os.environ['GREETINGS'] = 'Hello TensorFlow!'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gQkIX-cg18md"
   },
   "source": [
    "然后，您可以从子进程访问环境变量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:19.148549Z",
     "iopub.status.busy": "2023-11-07T23:15:19.148286Z",
     "iopub.status.idle": "2023-11-07T23:15:19.194941Z",
     "shell.execute_reply": "2023-11-07T23:15:19.194072Z"
    },
    "id": "pquKO6IA18G5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello TensorFlow!\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "echo ${GREETINGS}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "af6BCA-Y2fpz"
   },
   "source": [
    "在下一部分中，您将使用它来将 `'TF_CONFIG'` 传递给工作进程子进程。实际上，您永远不会以这种方式启动您的作业，但这完全可以满足此教程的演示目的：呈现最简单的多工作进程示例。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UhNtHfuxCGVy"
   },
   "source": [
    "## MultiWorkerMirroredStrategy\n",
    "\n",
    "在训练模型之前，首先创建一个 `tf.distribute.MultiWorkerMirroredStrategy` 的实例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:19.199657Z",
     "iopub.status.busy": "2023-11-07T23:15:19.198947Z",
     "iopub.status.idle": "2023-11-07T23:15:19.331499Z",
     "shell.execute_reply": "2023-11-07T23:15:19.330736Z"
    },
    "id": "1uFSHCJXMrQ-"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.AUTO\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:19.298847: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n"
     ]
    }
   ],
   "source": [
    "strategy = tf.distribute.MultiWorkerMirroredStrategy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N0iv7SyyAohc"
   },
   "source": [
    "注：在您调用 `tf.distribute.MultiWorkerMirroredStrategy` 时，会解析 `'TF_CONFIG'` 并启动 TensorFlow 的 GRPC 服务器。因此，您必须在实例化 `tf.distribute.Strategy` 之前设置 `'TF_CONFIG'` 环境变量。为了在这个说明性示例中节省时间，本教程中没有对此进行演示，因此不需要启动服务器。您可以在本教程的最后一个部分中找到完整的示例。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TS4S-faBHHam"
   },
   "source": [
    "使用 `tf.distribute.Strategy.scope` 指定构建模型时应使用的策略。这使得该策略可以控制变量放置之类的事情，它将在所有工作进程的每个设备上，在模型的层中创建所有变量的副本。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:19.335070Z",
     "iopub.status.busy": "2023-11-07T23:15:19.334808Z",
     "iopub.status.idle": "2023-11-07T23:15:19.441785Z",
     "shell.execute_reply": "2023-11-07T23:15:19.441061Z"
    },
    "id": "nXV49tG1_opc"
   },
   "outputs": [],
   "source": [
    "import mnist\n",
    "with strategy.scope():\n",
    "  # Model building needs to be within `strategy.scope()`.\n",
    "  multi_worker_model = mnist.build_cnn_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DSYkM-on6r3Y"
   },
   "source": [
    "## 在工作进程之间对数据进行自动分片\n",
    "\n",
    "在多工作进程训练中，需要通过*数据集分片*来确保收敛性和可重复性。分片意味着将整个数据集的一个子集交给每个工作进程，这有助于创造类似于对单个工作进程进行训练的体验。在下面的示例中，您依赖于 `tf.distribute` 的默认自动分片策略。您还可以通过设置 `tf.data.experimental.DistributeOptions` 的 `tf.data.experimental.AutoShardPolicy` 来对其进行自定义。要了解更多信息，请参阅[分布式输入教程](input.ipynb)的*分片*部分。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:19.446767Z",
     "iopub.status.busy": "2023-11-07T23:15:19.446050Z",
     "iopub.status.idle": "2023-11-07T23:15:20.082698Z",
     "shell.execute_reply": "2023-11-07T23:15:20.081648Z"
    },
    "id": "65-p36pt6rUF"
   },
   "outputs": [],
   "source": [
    "per_worker_batch_size = 64\n",
    "num_workers = len(tf_config['cluster']['worker'])\n",
    "global_batch_size = per_worker_batch_size * num_workers\n",
    "\n",
    "with strategy.scope():\n",
    "  multi_worker_dataset = strategy.distribute_datasets_from_function(\n",
    "      lambda input_context: mnist.dataset_fn(global_batch_size, input_context))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rkNzSR3g60iP"
   },
   "source": [
    "## 定义自定义训练循环并训练模型\n",
    "\n",
    "指定优化器："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:20.087270Z",
     "iopub.status.busy": "2023-11-07T23:15:20.086938Z",
     "iopub.status.idle": "2023-11-07T23:15:20.100395Z",
     "shell.execute_reply": "2023-11-07T23:15:20.099700Z"
    },
    "id": "NoMr4_zTeKSn"
   },
   "outputs": [],
   "source": [
    "with strategy.scope():\n",
    "  # The creation of optimizer and train_accuracy needs to be in\n",
    "  # `strategy.scope()` as well, since they create variables.\n",
    "  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)\n",
    "  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
    "      name='train_accuracy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RmrDcAii4B5O"
   },
   "source": [
    "使用 `tf.function` 定义训练步骤：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:20.104187Z",
     "iopub.status.busy": "2023-11-07T23:15:20.103914Z",
     "iopub.status.idle": "2023-11-07T23:15:20.110790Z",
     "shell.execute_reply": "2023-11-07T23:15:20.110133Z"
    },
    "id": "znXWN5S3eUDB"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(iterator):\n",
    "  \"\"\"Training step function.\"\"\"\n",
    "\n",
    "  def step_fn(inputs):\n",
    "    \"\"\"Per-Replica step function.\"\"\"\n",
    "    x, y = inputs\n",
    "    with tf.GradientTape() as tape:\n",
    "      predictions = multi_worker_model(x, training=True)\n",
    "      per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(\n",
    "          from_logits=True,\n",
    "          reduction=tf.keras.losses.Reduction.NONE)(y, predictions)\n",
    "      loss = tf.nn.compute_average_loss(\n",
    "          per_batch_loss, global_batch_size=global_batch_size)\n",
    "\n",
    "    grads = tape.gradient(loss, multi_worker_model.trainable_variables)\n",
    "    optimizer.apply_gradients(\n",
    "        zip(grads, multi_worker_model.trainable_variables))\n",
    "    train_accuracy.update_state(y, predictions)\n",
    "    return loss\n",
    "\n",
    "  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))\n",
    "  return strategy.reduce(\n",
    "      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eFXHsUVBy0Rx"
   },
   "source": [
    "### 检查点保存和恢复\n",
    "\n",
    "在编写自定义训练循环时，您需要手动处理[检查点保存](../../guide/checkpoint.ipynb)，而不是依赖 Keras 回调。请注意，对于 `MultiWorkerMirroredStrategy`，保存检查点或完整模型需要所有工作进程的参与，因为尝试仅在首席工作进程上进行保存可能会导致死锁。工作进程还需要写入不同的路径以避免相互重写。以下是如何配置目录的示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:20.114623Z",
     "iopub.status.busy": "2023-11-07T23:15:20.113953Z",
     "iopub.status.idle": "2023-11-07T23:15:20.120496Z",
     "shell.execute_reply": "2023-11-07T23:15:20.119689Z"
    },
    "id": "LcFO6x1KyjhI"
   },
   "outputs": [],
   "source": [
    "from multiprocessing import util\n",
    "checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')\n",
    "\n",
    "def _is_chief(task_type, task_id, cluster_spec):\n",
    "  return (task_type is None\n",
    "          or task_type == 'chief'\n",
    "          or (task_type == 'worker'\n",
    "              and task_id == 0\n",
    "              and \"chief\" not in cluster_spec.as_dict()))\n",
    "\n",
    "def _get_temp_dir(dirpath, task_id):\n",
    "  base_dirpath = 'workertemp_' + str(task_id)\n",
    "  temp_dir = os.path.join(dirpath, base_dirpath)\n",
    "  tf.io.gfile.makedirs(temp_dir)\n",
    "  return temp_dir\n",
    "\n",
    "def write_filepath(filepath, task_type, task_id, cluster_spec):\n",
    "  dirpath = os.path.dirname(filepath)\n",
    "  base = os.path.basename(filepath)\n",
    "  if not _is_chief(task_type, task_id, cluster_spec):\n",
    "    dirpath = _get_temp_dir(dirpath, task_id)\n",
    "  return os.path.join(dirpath, base)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nrcdPHtG4ObO"
   },
   "source": [
    "创建一个跟踪模型的 `tf.train.Checkpoint`，由 `tf.train.CheckpointManager` 管理，以便仅保留最新的检查点："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:20.124367Z",
     "iopub.status.busy": "2023-11-07T23:15:20.123804Z",
     "iopub.status.idle": "2023-11-07T23:15:20.133425Z",
     "shell.execute_reply": "2023-11-07T23:15:20.132707Z"
    },
    "id": "4rURT2pI4aqV"
   },
   "outputs": [],
   "source": [
    "epoch = tf.Variable(\n",
    "    initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')\n",
    "step_in_epoch = tf.Variable(\n",
    "    initial_value=tf.constant(0, dtype=tf.dtypes.int64),\n",
    "    name='step_in_epoch')\n",
    "task_type, task_id = (strategy.cluster_resolver.task_type,\n",
    "                      strategy.cluster_resolver.task_id)\n",
    "# Normally, you don't need to manually instantiate a `ClusterSpec`, but in this \n",
    "# illustrative example you did not set `'TF_CONFIG'` before initializing the\n",
    "# strategy. Check out the next section for \"real-world\" usage.\n",
    "cluster_spec = tf.train.ClusterSpec(tf_config['cluster'])\n",
    "\n",
    "checkpoint = tf.train.Checkpoint(\n",
    "    model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)\n",
    "\n",
    "write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,\n",
    "                                      cluster_spec)\n",
    "checkpoint_manager = tf.train.CheckpointManager(\n",
    "    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RO7cbN40XD5v"
   },
   "source": [
    "现在，当需要恢复检查点时，您可以方便地使用 `tf.train.latest_checkpoint` 函数（或通过调用 `tf.train.CheckpointManager.restore_or_initialize` ）找到最新的已保存检查点。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:20.137411Z",
     "iopub.status.busy": "2023-11-07T23:15:20.136897Z",
     "iopub.status.idle": "2023-11-07T23:15:20.140692Z",
     "shell.execute_reply": "2023-11-07T23:15:20.140009Z"
    },
    "id": "gniynaQj6HMV"
   },
   "outputs": [],
   "source": [
    "latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n",
    "if latest_checkpoint:\n",
    "  checkpoint.restore(latest_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1j9JuI-h6ObW"
   },
   "source": [
    "恢复检查点后，您可以继续训练自定义训练循环。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:20.143940Z",
     "iopub.status.busy": "2023-11-07T23:15:20.143513Z",
     "iopub.status.idle": "2023-11-07T23:15:24.222775Z",
     "shell.execute_reply": "2023-11-07T23:15:24.221984Z"
    },
    "id": "kZzXZCh45FY6"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:20.366756: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, accuracy: 0.807366, train_loss: 0.621664.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1, accuracy: 0.926786, train_loss: 0.255375.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2, accuracy: 0.947656, train_loss: 0.172921.\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 3\n",
    "num_steps_per_epoch = 70\n",
    "\n",
    "while epoch.numpy() < num_epochs:\n",
    "  iterator = iter(multi_worker_dataset)\n",
    "  total_loss = 0.0\n",
    "  num_batches = 0\n",
    "\n",
    "  while step_in_epoch.numpy() < num_steps_per_epoch:\n",
    "    total_loss += train_step(iterator)\n",
    "    num_batches += 1\n",
    "    step_in_epoch.assign_add(1)\n",
    "\n",
    "  train_loss = total_loss / num_batches\n",
    "  print('Epoch: %d, accuracy: %f, train_loss: %f.'\n",
    "                %(epoch.numpy(), train_accuracy.result(), train_loss))\n",
    "\n",
    "  train_accuracy.reset_states()\n",
    "\n",
    "  # Once the `CheckpointManager` is set up, you're now ready to save, and remove\n",
    "  # the checkpoints non-chief workers saved.\n",
    "  checkpoint_manager.save()\n",
    "  if not _is_chief(task_type, task_id, cluster_spec):\n",
    "    tf.io.gfile.rmtree(write_checkpoint_dir)\n",
    "\n",
    "  epoch.assign_add(1)\n",
    "  step_in_epoch.assign(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0W1Osks466DE"
   },
   "source": [
    "## 完整代码一览"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jfYpmIxO6Jck"
   },
   "source": [
    "总结一下到目前为止讨论的所有程序：\n",
    "\n",
    "1. 创建工作进程。\n",
    "2. 将 `'TF_CONFIG'` 传递给工作进程。\n",
    "3. 让每个工作进程运行下面包含训练代码的脚本。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:24.227041Z",
     "iopub.status.busy": "2023-11-07T23:15:24.226440Z",
     "iopub.status.idle": "2023-11-07T23:15:24.233423Z",
     "shell.execute_reply": "2023-11-07T23:15:24.232771Z"
    },
    "id": "MIDCESkVzN6M"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing main.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile main.py\n",
    "#@title File: `main.py`\n",
    "import os\n",
    "import json\n",
    "import tensorflow as tf\n",
    "import mnist\n",
    "from multiprocessing import util\n",
    "\n",
    "per_worker_batch_size = 64\n",
    "tf_config = json.loads(os.environ['TF_CONFIG'])\n",
    "num_workers = len(tf_config['cluster']['worker'])\n",
    "global_batch_size = per_worker_batch_size * num_workers\n",
    "\n",
    "num_epochs = 3\n",
    "num_steps_per_epoch=70\n",
    "\n",
    "# Checkpoint saving and restoring\n",
    "def _is_chief(task_type, task_id, cluster_spec):\n",
    "  return (task_type is None\n",
    "          or task_type == 'chief'\n",
    "          or (task_type == 'worker'\n",
    "              and task_id == 0\n",
    "              and 'chief' not in cluster_spec.as_dict()))\n",
    "    \n",
    "def _get_temp_dir(dirpath, task_id):\n",
    "  base_dirpath = 'workertemp_' + str(task_id)\n",
    "  temp_dir = os.path.join(dirpath, base_dirpath)\n",
    "  tf.io.gfile.makedirs(temp_dir)\n",
    "  return temp_dir\n",
    "\n",
    "def write_filepath(filepath, task_type, task_id, cluster_spec):\n",
    "  dirpath = os.path.dirname(filepath)\n",
    "  base = os.path.basename(filepath)\n",
    "  if not _is_chief(task_type, task_id, cluster_spec):\n",
    "    dirpath = _get_temp_dir(dirpath, task_id)\n",
    "  return os.path.join(dirpath, base)\n",
    "\n",
    "checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')\n",
    "\n",
    "# Define Strategy\n",
    "strategy = tf.distribute.MultiWorkerMirroredStrategy()\n",
    "\n",
    "with strategy.scope():\n",
    "  # Model building/compiling need to be within `tf.distribute.Strategy.scope`.\n",
    "  multi_worker_model = mnist.build_cnn_model()\n",
    "\n",
    "  multi_worker_dataset = strategy.distribute_datasets_from_function(\n",
    "      lambda input_context: mnist.dataset_fn(global_batch_size, input_context))        \n",
    "  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)\n",
    "  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
    "      name='train_accuracy')\n",
    "\n",
    "@tf.function\n",
    "def train_step(iterator):\n",
    "  \"\"\"Training step function.\"\"\"\n",
    "\n",
    "  def step_fn(inputs):\n",
    "    \"\"\"Per-Replica step function.\"\"\"\n",
    "    x, y = inputs\n",
    "    with tf.GradientTape() as tape:\n",
    "      predictions = multi_worker_model(x, training=True)\n",
    "      per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(\n",
    "          from_logits=True,\n",
    "          reduction=tf.keras.losses.Reduction.NONE)(y, predictions)\n",
    "      loss = tf.nn.compute_average_loss(\n",
    "          per_batch_loss, global_batch_size=global_batch_size)\n",
    "\n",
    "    grads = tape.gradient(loss, multi_worker_model.trainable_variables)\n",
    "    optimizer.apply_gradients(\n",
    "        zip(grads, multi_worker_model.trainable_variables))\n",
    "    train_accuracy.update_state(y, predictions)\n",
    "\n",
    "    return loss\n",
    "\n",
    "  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))\n",
    "  return strategy.reduce(\n",
    "      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)\n",
    "\n",
    "epoch = tf.Variable(\n",
    "    initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')\n",
    "step_in_epoch = tf.Variable(\n",
    "    initial_value=tf.constant(0, dtype=tf.dtypes.int64),\n",
    "    name='step_in_epoch')\n",
    "\n",
    "task_type, task_id, cluster_spec = (strategy.cluster_resolver.task_type,\n",
    "                                    strategy.cluster_resolver.task_id,\n",
    "                                    strategy.cluster_resolver.cluster_spec())\n",
    "\n",
    "checkpoint = tf.train.Checkpoint(\n",
    "    model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)\n",
    "\n",
    "write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,\n",
    "                                      cluster_spec)\n",
    "checkpoint_manager = tf.train.CheckpointManager(\n",
    "    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)\n",
    "\n",
    "# Restoring the checkpoint\n",
    "latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n",
    "if latest_checkpoint:\n",
    "  checkpoint.restore(latest_checkpoint)\n",
    "\n",
    "# Resume our CTL training\n",
    "while epoch.numpy() < num_epochs:\n",
    "  iterator = iter(multi_worker_dataset)\n",
    "  total_loss = 0.0\n",
    "  num_batches = 0\n",
    "\n",
    "  while step_in_epoch.numpy() < num_steps_per_epoch:\n",
    "    total_loss += train_step(iterator)\n",
    "    num_batches += 1\n",
    "    step_in_epoch.assign_add(1)\n",
    "\n",
    "  train_loss = total_loss / num_batches\n",
    "  print('Epoch: %d, accuracy: %f, train_loss: %f.'\n",
    "                %(epoch.numpy(), train_accuracy.result(), train_loss))\n",
    "  \n",
    "  train_accuracy.reset_states()\n",
    "\n",
    "  checkpoint_manager.save()\n",
    "  if not _is_chief(task_type, task_id, cluster_spec):\n",
    "    tf.io.gfile.rmtree(write_checkpoint_dir)\n",
    "\n",
    "  epoch.assign_add(1)\n",
    "  step_in_epoch.assign(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ItVOvPN1qnZ6"
   },
   "source": [
    "当前目录现包含两个 Python 文件："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:24.237059Z",
     "iopub.status.busy": "2023-11-07T23:15:24.236529Z",
     "iopub.status.idle": "2023-11-07T23:15:24.300080Z",
     "shell.execute_reply": "2023-11-07T23:15:24.299121Z"
    },
    "id": "bi6x05Sr60O9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "main.py\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mnist.py\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "ls *.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qmEEStPS6vR_"
   },
   "source": [
    "因此，对 `'TF_CONFIG'` 执行 JSON 序列化，然后将其添加到环境变量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:24.304381Z",
     "iopub.status.busy": "2023-11-07T23:15:24.303636Z",
     "iopub.status.idle": "2023-11-07T23:15:24.307915Z",
     "shell.execute_reply": "2023-11-07T23:15:24.307209Z"
    },
    "id": "9uu3g7vV7Bbt"
   },
   "outputs": [],
   "source": [
    "os.environ['TF_CONFIG'] = json.dumps(tf_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MsY3dQLK7jdf"
   },
   "source": [
    "现在，您可以启动一个将运行 `main.py` 并使用 `'TF_CONFIG'` 的工作进程："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:24.311524Z",
     "iopub.status.busy": "2023-11-07T23:15:24.311017Z",
     "iopub.status.idle": "2023-11-07T23:15:24.315129Z",
     "shell.execute_reply": "2023-11-07T23:15:24.314498Z"
    },
    "id": "txMXaq8d8N_S"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All background processes were killed.\n"
     ]
    }
   ],
   "source": [
    "# first kill any previous runs\n",
    "%killbgscripts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:24.318373Z",
     "iopub.status.busy": "2023-11-07T23:15:24.318092Z",
     "iopub.status.idle": "2023-11-07T23:15:24.375371Z",
     "shell.execute_reply": "2023-11-07T23:15:24.374149Z"
    },
    "id": "qnSma_Ck7r-r"
   },
   "outputs": [],
   "source": [
    "%%bash --bg\n",
    "python main.py &> job_0.log"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZChyazqS7v0P"
   },
   "source": [
    "以上命令有几点需要注意：\n",
    "\n",
    "1. 它使用 `%%bash`，这是一项用于运行一些 bash 命令的[笔记本“魔术命令”](https://ipython.readthedocs.io/en/stable/interactive/magics.html)。\n",
    "2. 它使用 `--bg` 标志在后台运行 `bash` 进程，因为此工作进程不会终止。它在开始之前会等待所有工作进程。\n",
    "\n",
    "后台工作进程不会将输出打印到此笔记本。`&>` 会将其输出重定向到一个文件，以便您可以查看所发生的情况。\n",
    "\n",
    "等待几秒钟以启动该进程："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:24.379842Z",
     "iopub.status.busy": "2023-11-07T23:15:24.379527Z",
     "iopub.status.idle": "2023-11-07T23:15:44.404224Z",
     "shell.execute_reply": "2023-11-07T23:15:44.403201Z"
    },
    "id": "Hm2yrULE9281"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "time.sleep(20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZFPoNxg_9_Mx"
   },
   "source": [
    "接下来，检查一下目前为止输出到工作进程日志文件的内容："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:44.409197Z",
     "iopub.status.busy": "2023-11-07T23:15:44.408493Z",
     "iopub.status.idle": "2023-11-07T23:15:44.474674Z",
     "shell.execute_reply": "2023-11-07T23:15:44.473733Z"
    },
    "id": "vZEOuVgQ9-hn"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:24.897952: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:24.898016: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:24.899709: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:27.043487: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "cat job_0.log"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RqZhVF7L_KOy"
   },
   "source": [
    "日志文件的最后一行内容应为：`Started server with target: grpc://localhost:12345`。第一个工作进程现已准备就绪，正在等待所有其他工作进程准备就绪以继续。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Pi8vPNNA_l4a"
   },
   "source": [
    "更新 `tf_config` 以供第二个工作进程取用："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:44.478977Z",
     "iopub.status.busy": "2023-11-07T23:15:44.478379Z",
     "iopub.status.idle": "2023-11-07T23:15:44.483079Z",
     "shell.execute_reply": "2023-11-07T23:15:44.482298Z"
    },
    "id": "lAiYkkPu_Jqd"
   },
   "outputs": [],
   "source": [
    "tf_config['task']['index'] = 1\n",
    "os.environ['TF_CONFIG'] = json.dumps(tf_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0AshGVO0_x0w"
   },
   "source": [
    "现在，启动第二个工作进程。这将开始训练，因为所有工作进程都已处于活动状态（因此无需在后台执行此进程）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:44.486857Z",
     "iopub.status.busy": "2023-11-07T23:15:44.486284Z",
     "iopub.status.idle": "2023-11-07T23:15:59.434365Z",
     "shell.execute_reply": "2023-11-07T23:15:59.433017Z"
    },
    "id": "_ESVtyQ9_xjx"
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "python main.py > /dev/null 2>&1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hX4FA2O2AuAn"
   },
   "source": [
    "如果您重新检查第一个工作进程编写的日志，您会看到它参与了该模型的训练："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:59.441138Z",
     "iopub.status.busy": "2023-11-07T23:15:59.440438Z",
     "iopub.status.idle": "2023-11-07T23:15:59.507207Z",
     "shell.execute_reply": "2023-11-07T23:15:59.506105Z"
    },
    "id": "rc6hw3yTBKXX"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:24.897952: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:24.898016: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:24.899709: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:27.043487: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:15:48.287770: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, accuracy: 0.804129, train_loss: 0.624825.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1, accuracy: 0.920201, train_loss: 0.276320.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2, accuracy: 0.946429, train_loss: 0.194815.\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "cat job_0.log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:15:59.511709Z",
     "iopub.status.busy": "2023-11-07T23:15:59.511401Z",
     "iopub.status.idle": "2023-11-07T23:15:59.516717Z",
     "shell.execute_reply": "2023-11-07T23:15:59.516003Z"
    },
    "id": "sG5_1UgrgniF"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All background processes were killed.\n"
     ]
    }
   ],
   "source": [
    "# Delete the `'TF_CONFIG'`, and kill any background tasks so they don't affect the next section.\n",
    "os.environ.pop('TF_CONFIG', None)\n",
    "%killbgscripts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bhxMXa0AaZkK"
   },
   "source": [
    "## 深入了解多工作进程训练\n",
    "\n",
    "本教程演示了多工作进程设置的自定义训练循环工作流程。有关其他主题的详细描述可在适用于自定义训练循环的[使用 Keras 进行多工作进程训练 (`tf.keras.Model.fit`)](multi_worker_with_keras.ipynb) 教程中找到。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ega2hdOQEmy_"
   },
   "source": [
    "## 了解更多\n",
    "\n",
    "1. [TensorFlow 中的分布式训练](../../guide/distributed_training.ipynb)指南概述了可用的分布式策略。\n",
    "2. [官方模型](https://github.com/tensorflow/models/tree/master/official)，其中许多模型可以配置为运行多个分布式策略。\n",
    "3. `tf.function` 指南中的[“性能”部分](../../guide/function.ipynb)提供了有关其他策略和[工具](../../guide/profiler.md)的信息，您可以使用它们来优化 TensorFlow 模型的性能。\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "multi_worker_with_ctl.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
